{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Structured prediction\n",
    "\n",
    "\n",
    "\n",
    "In this example$\\newcommand{\\reals}{\\mathbf{R}}$$\\newcommand{\\ones}{\\mathbf{1}}$, we fit a regression model to structured data, using a log-log convex program (LLCP).\n",
    "The training dataset $\\mathcal D$ contains $N$ input-output pairs $(x, y)$,\n",
    "where $x \\in \\reals^{n}_{++}$ is an input and $y \\in \\reals^{m}_{++}$ is an\n",
    "output. The entries of each output $y$ are sorted in ascending order, meaning\n",
    "$y_1 \\leq y_2 \\leq \\cdots \\leq y_m$.\n",
    "\n",
    "Our regression model $\\phi : \\reals^{n}_{++} \\to \\reals^{m}_{++}$ takes as\n",
    "input a vector $x \\in \\reals^{n}_{++}$, and solves an LLCP to produce a\n",
    "prediction $\\hat y \\in \\reals^{m}_{++}$; that is, the solution of the\n",
    "LLCP is the model's prediction. The model is of the form\n",
    "$$\n",
    "\\begin{equation}\n",
    "\\begin{array}{lll}\n",
    "\\phi(x) = &\n",
    "\\mbox{argmin} & \\ones^T (z/y + y / z) \\\\\n",
    "& \\mbox{subject to} &  y_i \\leq y_{i+1}, \\quad i=1, \\ldots, m-1 \\\\\n",
    "&& z_i = c_i x_1^{A_{i1}}x_2^{A_{i2}}\\cdots x_n^{A_{in}}, \\quad i = 1, \\ldots, m.\n",
    "\\end{array}\\label{e-model}\n",
    "\\end{equation}\n",
    "$$\n",
    "Here, the minimization is over $y \\in \\reals^{m}_{++}$ and an auxiliary\n",
    "variable $z \\in \\reals^{m}_{++}$, $\\phi(x)$ is the optimal $y$, and\n",
    "the parameters are $c \\in \\reals^{m}_{++}$ and $A \\in \\reals^{m \\times n}$. The\n",
    "ratios in the objective are meant elementwise, and\n",
    "$\\ones$ denotes the vector of all ones. Given a vector $x$, this model finds a\n",
    "sorted vector $\\hat y$ whose entries are close to monomial functions of $x$\n",
    "(which are the entries of $z$), as measured by the fractional error.\n",
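    "\n",
    "To see why this problem is an LLCP, substitute $u = \\log y$ and $w = \\log z$ (elementwise). Then\n",
    "$$\n",
    "\\frac{z_i}{y_i} + \\frac{y_i}{z_i} = e^{w_i - u_i} + e^{u_i - w_i} = 2\\cosh(w_i - u_i),\n",
    "\\qquad\n",
    "w_i = \\log c_i + \\sum_{j=1}^n A_{ij} \\log x_j,\n",
    "$$\n",
    "so the objective is a convex function of $(u, w)$, each monomial equality is affine in $(u, w)$, and each ordering constraint $y_i \\leq y_{i+1}$ becomes the linear inequality $u_i \\leq u_{i+1}$. Each term $2\\cosh(w_i - u_i)$ attains its minimum value $2$ exactly when $y_i = z_i$.\n",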
    "\n",
    "\n",
    "The training loss\n",
    "$\\mathcal{L}(\\phi)$ of the model on the training set is the mean squared loss\n",
    "$$\n",
    "\\mathcal{L}(\\phi) = \\frac{1}{N}\\sum_{(x, y) \\in \\mathcal D} \\|y - \\phi(x)\\|_2^2.\n",
    "$$\n",
    "We emphasize that $\\mathcal{L}(\\phi)$ depends on $c$ and $A$.\n",
    "In this example, we fit the parameters $c$ and $A$ in the LLCP \n",
    "to minimize the training loss $\\mathcal{L}(\\phi)$.\n",
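    "\n",
    "Since $\\phi(x)$ depends on the parameters only through the solution of the LLCP, the chain rule gives\n",
    "$$\n",
    "\\nabla_c \\mathcal{L}(\\phi) = \\frac{2}{N} \\sum_{(x, y) \\in \\mathcal D} \\left(\\mathrm{D}_c \\phi(x)\\right)^T \\left(\\phi(x) - y\\right),\n",
    "$$\n",
    "where $\\mathrm{D}_c \\phi(x) \\in \\reals^{m \\times m}$ is the derivative of the solution map with respect to $c$; the gradient with respect to $A$ has the same form, with $\\mathrm{D}_A \\phi(x)$ in place of $\\mathrm{D}_c \\phi(x)$.\n",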
    "\n",
    "**Fitting.** We fit the parameters by an iterative projected\n",
    "gradient descent method on $\\mathcal L(\\phi)$. In each iteration, we first\n",
    "compute predictions $\\phi(x)$ for each input in the training set; this requires\n",
    "solving $N$ LLCPs. Next, we evaluate the training loss $\\mathcal L(\\phi)$. To\n",
    "update the parameters, we compute the gradient $\\nabla \\mathcal L(\\phi)$ of the\n",
    "training loss with respect to the parameters $c$ and $A$. This\n",
    "requires differentiating through the solution map of the LLCP, which can be done efficiently with the `backward` method in CVXPY\n",
    "(or with CVXPY Layers, as in this notebook). Finally, we subtract\n",
    "a small multiple of the gradient from the parameters. Care must be taken to\n",
    "ensure that $c$ is strictly positive; this can be done by clamping the entries\n",
    "of $c$ at some small threshold slightly above zero. We run this method for\n",
    "a fixed number of iterations.\n",
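    "\n",
    "Concretely, with step size $t > 0$ and a small threshold $\\epsilon > 0$, one iteration of this method performs the updates\n",
    "$$\n",
    "A \\leftarrow A - t \\nabla_A \\mathcal{L}(\\phi), \\qquad\n",
    "c \\leftarrow \\max\\left(c - t \\nabla_c \\mathcal{L}(\\phi),\\ \\epsilon \\ones\\right),\n",
    "$$\n",
    "where the maximum is taken elementwise. In the code below, $t = 0.05$ and $\\epsilon = 10^{-8}$.\n",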
    "\n",
    "This example is described in the paper [Differentiating through Log-Log Convex Programs](http://web.stanford.edu/~boyd/papers/pdf/diff_llcvx.pdf).\n",
    "\n",
    "Shane Barratt formulated the idea of using an optimization layer to regress on sorted vectors.\n",
    "\n",
    "**Requirements.**\n",
    "This example requires PyTorch and CvxpyLayers >= v0.1.3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cvxpylayers.torch import CvxpyLayer\n",
    "\n",
    "\n",
    "import cvxpy as cp\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Use float64 as the default dtype for new floating-point tensors\n",
    "torch.set_default_dtype(torch.float64)\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 20\n",
    "m = 10\n",
    "\n",
    "# Number of training input-output pairs\n",
    "N = 100\n",
    "\n",
    "# Number of validation pairs\n",
    "N_val = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.random.manual_seed(243)\n",
    "np.random.seed(243)\n",
    "\n",
    "# Standard normal on R^n; exponentiating its samples gives lognormal vectors\n",
    "normal = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(n), torch.eye(n))\n",
    "lognormal = lambda batch: torch.exp(normal.sample((batch,)))\n",
    "\n",
    "# Ground-truth parameters of the data-generating model\n",
    "A_true = torch.randn((m, n)) / 10\n",
    "c_true = torch.randn(m).abs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_data(num_points, seed):\n",
    "    \"\"\"Sample positive inputs and compute sorted outputs using (A_true, c_true).\"\"\"\n",
    "    torch.random.manual_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    latent = lognormal(num_points)\n",
    "    noise = lognormal(num_points)\n",
    "    inputs = noise + latent\n",
    "\n",
    "    input_cp = cp.Parameter(pos=True, shape=(n,))\n",
    "    prediction = cp.multiply(c_true.numpy(), cp.gmatmul(A_true.numpy(), input_cp))\n",
    "    y = cp.Variable(pos=True, shape=(m,))\n",
    "    objective_fn = cp.sum(prediction / y + y/prediction)\n",
    "    constraints = []\n",
    "    for i in range(m-1):\n",
    "        constraints += [y[i] <= y[i+1]]\n",
    "    problem = cp.Problem(cp.Minimize(objective_fn), constraints)\n",
    "\n",
    "    # Solve the LLCP once per input; its solution is the (sorted) output\n",
    "    outputs = []\n",
    "    for i in range(num_points):\n",
    "        input_cp.value = inputs[i, :].numpy()\n",
    "        problem.solve(cp.SCS, gp=True)\n",
    "        outputs.append(y.value)\n",
    "    return inputs, torch.stack([torch.tensor(t) for t in outputs])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x12b367cd0>]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAZ1klEQVR4nO3dfXRU933n8fdXEuJZiAcBRhISTrCxgh+whGvH29iNnY1de6GN4yzsNo3d1N7klMTppmmdnNbJcTbJ2XOadLNn2T11HTvPoTZNWtrSpdm1025y3EQzgB8AQ1Q8gyQwCDQS4kmP3/1DI3sYhDRCM7pzrz6vc3TO3Ds/3fudC/ro6ve793fN3RERkfArCboAERHJDwW6iEhEKNBFRCJCgS4iEhEKdBGRiCgLasdLlizx+vr6oHYvIhJK8Xj8pLtXjfZeYIFeX19PLBYLavciIqFkZsnLvacuFxGRiFCgi4hEhAJdRCQiFOgiIhGhQBcRiYicAt3M7jGzg2bWYmaPj/J+nZn9XzN7xcx+YmY1+S9VRETGMm6gm1kpsBW4F2gANptZQ1azPwW+7e43AE8CX8l3oSIiMrZcztBvAVrc/bC79wHbgI1ZbRqAF9KvXxzlfRGRaW9oyPnyzgO80tZVkO3nEujVQGvGclt6XaaXgQ+kX/8mMN/MFmdvyMweNbOYmcU6OjqupF4RkdA6dKKHp/75ML88fqYg28/XoOgfAHeY2R7gDqAdGMxu5O5PuXuTuzdVVY1656qISGTFEikAmuoXFmT7udz63w7UZizXpNe9xd2Pkj5DN7N5wAPuXpi/KUREQiqeTFE1fyYrF80pyPZzOUNvBlab2SozKwc2ATsyG5jZEjMb2dZngWfyW6aISPg1JzppqluImRVk++MGursPAFuAXcAB4Dl332dmT5rZhnSzO4GDZnYIWAZ8qSDVioiE1JvdF2hLnaepflHB9pHTbIvuvhPYmbXuiYzX24Ht+S1NRCQ6YslOAJrqCtN/DrpTVERkSsQSKWbPKKVhRUXB9qFAFxGZArFkJzfVVjKjtHCxq0AXESmwM70D7D96umCXK45QoIuIFNjeI10MOQUdEAUFuohIwcWSnZjBupWVBd2PAl1EpMDiyRRrlldQMWtGQfejQBcRKaCBwSF2J1MFvVxxhAJdRKSAXn+zh7N9gwUfEAUFuohIQcUS6RuKCjwgCgp0EZGCiiVTXLVgFtWVswu+LwW6iEiBuDuxRGpKzs5BgS4iUjDtXed58/SFKRkQBQW6iEjBFPqBFtkU6CIiBRJLdjJvZhlrlhduQq5MCnQRkQKJJVKsW1lJaUlhHmiRTYEuIlIA3ef7OXi8h6a6qRkQBQW6iEhB7DmSwh3WT1H/OSjQRUQKIpZIUVpi3FTgCbkyKdBFRAogluzkXSsqmFOe05M+80KBLiKSZ/2DQ+xt7aJxiq4/H6FAFxHJs31HT3Ohf2hKB0RBgS4ikndvT8ilM3QRkVCLJVLULprNsopZU7pfBbqISB65O7FkJ+unuLsFFOgiInmVPHWOk2f6aJzi7hZQoIuI5FUsOTwh1/opmjI3U06Bbmb3mNlBM2sxs8dHeX+lmb1oZnvM7BUz+/X8lyoiUvxiiU4qZpXxzqp5U77vcQPdzEqBrcC9QAOw2cwaspr9MfCcu68DNgH/M9+FioiEQSw5/ECLkimakCtTLmfotwAt7n7Y3fuAbcDGrDYOjMwPuQA4mr8SRUTCIXW2j5YTZ6b8hqIRuQR6NdCasdyWXpfpC8BvmVkbsBP4xGgbMrNHzSxmZrGOjo4rKFdEpHjF0/3nU/WEomz5GhTdDHzT3WuAXwe+Y2aXbNvdn3L3JndvqqqqytOuRUSKQ3Oykxmlxo21UzchV6ZcAr0dqM1Yrkmvy/RR4DkAd38JmAUsyUeBIiJhEU+kWFu9gFkzSgPZfy6B3gysNrNVZlbO8KDnjqw2
R4C7AMzsOoYDXX0qIjJtXOgf5JW27kAuVxwxbqC7+wCwBdgFHGD4apZ9ZvakmW1IN/s08IiZvQz8AHjI3b1QRYuIFJvX2rvpGxwKbEAUIKeJet19J8ODnZnrnsh4vR+4Pb+liYiER3Mi2AFR0J2iIiJ5EU92cvWSuSyeNzOwGhToIiKTNDTkxJOpKZ8uN5sCXURkkg6fPEPqXP+UP9AimwJdRGSSYun+8yBmWMykQBcRmaTmRIpFc8u5esncQOtQoIuITFI82Ulj3ULMpn5CrkwKdBGRSejo6SVx6hzrA+5uAQW6iMikxJPDD4RuDHhAFBToIiKT0pxIMbOshLXVFeM3LjAFuojIJMSSKW6sqWRmWTATcmVSoIuIXKHzfYPsa+8O/IaiEQp0EZErtLe1i4EhV6CLiITdWwOiK4MfEAUFuojIFWtOpLhm2TwWzJkRdCmAAl1E5IoMDjm7j6SK4nLFEQp0EZErcOh4Dz0XBorihqIRCnQRkSsQS4480EJn6CIioRZLdLJ0/kxqF80OupS3KNBFRK5ALDH8QIugJ+TKpEAXEZmgY93nae86X1TdLaBAFxGZsJEHWhTLDUUjFOgiIhMUT6aYU15Kw1XBT8iVSYEuIjJBzYlObqqtpKy0uCK0uKoRESlyZ3oHOHDsNE11xdXdAgp0EZEJ2XMkxZBDU31xDYiCAl1EZEJiiRQlButWVgZdyiVyCnQzu8fMDppZi5k9Psr7f2Zme9Nfh8ysK/+liogEL5bsZM3yCubPKo4JuTKVjdfAzEqBrcD7gDag2cx2uPv+kTbu/vsZ7T8BrCtArSIigRoYHGLPkS4+2FgTdCmjyuUM/Ragxd0Pu3sfsA3YOEb7zcAP8lGciEgxOXCsh3N9g0XZfw65BXo10Jqx3JZedwkzqwNWAS9MvjQRkeISSz/QohivcIH8D4puAra7++Bob5rZo2YWM7NYR0dHnnctIlJYsUSK6srZrKgsngm5MuUS6O1AbcZyTXrdaDYxRneLuz/l7k3u3lRVVZV7lSIiAXN3YslOGov07BxyC/RmYLWZrTKzcoZDe0d2IzNbAywEXspviSIiwWtLnef46d6im78l07iB7u4DwBZgF3AAeM7d95nZk2a2IaPpJmCbu3thShURCc7b/efFOSAKOVy2CODuO4GdWeueyFr+Qv7KEhEpLrFEivkzy7h2+fygS7ks3SkqIpKDWCLFurqFlJYUzwMtsinQRUTG0X2un0Mneor2csURCnQRkXHsPpLCvfgeaJFNgS4iMo5YspPSEuOm2uKbkCuTAl1EZBzNiRRrV1Qwpzyn60gCo0AXERlD38AQL7d20VjElyuOUKCLiIzhtaPd9A4MFX3/OSjQRUTGFE+kgOKdkCuTAl1EZAyxZCcrF81hacWsoEsZlwJdROQy3J1YIhWK7hZQoIuIXFbi1DlOne0r6vlbMinQRUQuozkxPCHXep2hi4iEWzyRYsHsGbyjal7QpeREgS4ichnNyU6a6hZSUsQTcmVSoIuIjOLUmV4Od5ylMSTdLaBAFxEZVTw5fP35+vpwDIiCAl1EZFTxZIry0hKur14QdCk5U6CLiIwilkyxtrqCWTNKgy4lZwp0EZEsF/oHebWtO1TdLaBAFxG5xKvt3fQNDtEYgvlbMinQRUSyjNxQpEAXEQm5eCLF1VVzWTxvZtClTIgCXUQkw9CQE0umWB+S+VsyKdBFRDL8a8cZus/3h+qGohEKdBGRDM2J8N1QNEKBLiKSIZbsZPHccuoXzwm6lAlToIuIZIglUjTWLcQsHBNyZcop0M3sHjM7aGYtZvb4Zdp8yMz2m9k+M/t+fssUESm8Ez0XONJ5LpTdLQBl4zUws1JgK/A+oA1oNrMd7r4/o81q4LPA7e6eMrOlhSpYRKRQRh4IHcYBUcjtDP0WoMXdD7t7H7AN2JjV5hFgq7unANz9RH7LFBEpvOZEipllJaxdEZ4JuTLlEujVQGvGclt6XaZrgGvM7Gdm
9i9mds9oGzKzR80sZmaxjo6OK6tYRKRA4slObqytpLwsnMOL+aq6DFgN3AlsBv7CzCqzG7n7U+7e5O5NVVVVedq1iMjknesb4LWjp0Pz/NDR5BLo7UBtxnJNel2mNmCHu/e7+xvAIYYDXkQkFPa2djE45DSF8A7REbkEejOw2sxWmVk5sAnYkdXmrxk+O8fMljDcBXM4j3WKiBRULJHCDG5eGeEzdHcfALYAu4ADwHPuvs/MnjSzDelmu4BTZrYfeBH4jLufKlTRIiL5FkumuGbpfBbMmRF0KVds3MsWAdx9J7Aza90TGa8d+M/pLxGRUBkccnYnU2y4aUXQpUxKOIdyRUTy6OCbPZzpHQj1gCgo0EVEiCeHH2gR5gFRUKCLiNCcSLGsYiY1C2cHXcqkKNBFZNqLJ1M01S0K5YRcmRToIjKtHe06T3vXeZpC3n8OCnQRmeZiyeEJucLefw4KdBGZ5mKJTuaUl3LdVfODLmXSFOgiMq3FEinWraykrDT8cRj+TyAicoV6LvTz+punI9HdAgp0EZnG9hzpYsiJxIAoKNBFZBqLJTopMVgX4gm5MinQRWTaiiVTXHdVBfNm5jStVdFToIvItNQ/OMTe1i6a6qJxdg4KdBGZpg4cO825vkGa6qMxIAoKdBGZpmKJ9A1FERkQBQW6iExTsWQn1ZWzuWpBuCfkyqRAF5Fpx92JJVKROjsHBbqITEOtnec50dMbqf5zUKCLyDQUe+uBFjpDFxEJteZEivkzy7hmWfgn5MqkQBeRaSee7OTmuoWUloT7gRbZFOgiMq10nevj0PEzketuAQW6iEwzu4+MXH8erQFRUKCLyDQTS6QoKzFuqq0MupS8U6CLyLQSS6R4V/UCZpeXBl1K3inQRWTa6B0Y5OW2aE3IlSmnQDeze8zsoJm1mNnjo7z/kJl1mNne9Nfv5r9UEZHJea39NL0DQ6yP2B2iI8adBNjMSoGtwPuANqDZzHa4+/6spn/p7lsKUKOISF7E0zcUNUbkkXPZcjlDvwVocffD7t4HbAM2FrYsEZH8a06kqFs8h6r5M4MupSByCfRqoDVjuS29LtsDZvaKmW03s9rRNmRmj5pZzMxiHR0dV1CuiMiVcXfiyVRkHgg9mnwNiv4tUO/uNwA/Br41WiN3f8rdm9y9qaqqKk+7FhEZ3+GTZ+k82xe5GRYz5RLo7UDmGXdNet1b3P2Uu/emF58GGvNTnohIfsTTD7SI6oAo5BbozcBqM1tlZuXAJmBHZgMzuypjcQNwIH8liohMXizZSeWcGVy9ZF7QpRTMuFe5uPuAmW0BdgGlwDPuvs/MngRi7r4D+KSZbQAGgE7goQLWLCIyIef7BvlZyyma6hZSErEJuTKNG+gA7r4T2Jm17omM158FPpvf0kREJm9wyHls2x6Odp/nS7+5NuhyCkp3iopIpH1l5wH+cf9xnri/gTuvXRp0OQWlQBeRyPrOSwme/ukbPPTueh6+fVXQ5RScAl1EIumF14/z+R37uPu6pfzJ/Q1BlzMlFOgiEjmvtXez5ft7aFhRwdc3rYvck4kuR4EuIpFyrPs8H/1WM5WzZ/CNj6xn7sycrv2IhOnzSUUk8s70DvA734xxtneQ7R+/jWUVs4IuaUop0EUkEgYGh/i97+3m0PEenn1oPWuWVwRd0pRTl4uIhJ678/kd+/inQx38l99Yy3uumZ5zRSnQRST0nv5/b/C9nx/hY3e8g823rAy6nMAo0EUk1P7h1WN8+R8OcN/1V/GH77826HICpUAXkdDacyTFp/5yL+tqK/nqh26M9DwtuVCgi0gotXae45Fvx1hWMYu/+O0mZs0oDbqkwCnQRSR0us/18/A3m+kfdJ55aD2L50XzkXITpUAXkVDpGxjiY9+Nkzx1lj//cCPvXBrd+c0nStehi0houDuf+9GrvHT4FF/70I3cevXioEsqKjpDF5HQ+B8vtLA93sZjd63mAzfXBF1O0VGgi0go/M3edr7640N8
YF01n7p7ddDlFCUFuogUvV+80clnnn+FX1m1iK88cD1m0/vyxMtRoItIUTvccYZHvxOjZtFs/vzDjcws0+WJl6NAF5Gi1Xm2j9/5ZjMlZjz70Hoq55QHXVJR01UuIlKULvQP8si3YxztvsAPHrmVusVzgy6p6OkMXUSKztCQ8wfPv0w8meLPPnQTjXULgy4pFBToIlJ0vvrjg/zdK8f4o3vWcN8NVwVdTmgo0EWkqDzX3MrWF/+VzbfU8rE7rg66nFBRoItI0fjpL0/yuR+9yq+uXsKTG9fq8sQJUqCLSFE4dLyHj383zjuq5rH1P97MjFLF00TldMTM7B4zO2hmLWb2+BjtHjAzN7Om/JUoIlF3oucCDz/bzKzyUp55eD0Vs2YEXVIojRvoZlYKbAXuBRqAzWbWMEq7+cBjwM/zXaSIRNe5vgF+91sxOs/28cxH1lNdOTvokkIrlzP0W4AWdz/s7n3ANmDjKO2+CPxX4EIe6xORCBscch7btpdX27v575vXcX3NgqBLCrVcAr0aaM1Ybkuve4uZ3QzUuvvfj7UhM3vUzGJmFuvo6JhwsSISLV/eeYAf7z/OE/c38L6GZUGXE3qTHnUwsxLga8Cnx2vr7k+5e5O7N1VVVU121yISYt9+KcE3fvoGD727nodvXxV0OZGQS6C3A7UZyzXpdSPmA2uBn5hZArgV2KGBURG5nBdeP84Xduzj7uuW8if3XzIkJ1col0BvBlab2SozKwc2ATtG3nT3bndf4u717l4P/Auwwd1jBalYRELttfZutnx/Dw0rKvj6pnWUluha83wZN9DdfQDYAuwCDgDPufs+M3vSzDYUukARiY5j3ef56LeaqZw9g298ZD1zZ2p+wHzK6Wi6+05gZ9a6Jy7T9s7JlyUiUdNzoZ+Hn23mbO8g2z9+G8sqZgVdUuTo16OIFNzA4BBbvr+HX544wzMPrWfN8oqgS4ok3VsrIgXl7nx+xz7+6VAHX9y4ljuu0RVuhaIz9Ctw/PQFfrSnndePnb5sm7EmFRpzCGiMN+0yb050/qKJNJ/Iti9X30Tkay4m94zX+EXrfIw2XNLGM9pc/H1jtRlN9kcb7f/IpW0mto2JHr6x6oWLP9uVbuP0+X5ePNjBf7rjav7Dr6zMuTaZOAV6jnoHBvk/+0/wfLyVfz7UwZBDzcLZo47Qj/Uz4GP89x/z+8b/ucpJLj+gb7Wd0HYnXsul+5v8RtzfDsHMXzBvrxtZHiVMLbutXbQ82vdb9huMHqqXfLJRPmr2qux/q0vfz/7+SzeaeTwuJ5dfxONvY2y/fVsdf/T+NePuRyZHgT4Gd2ff0dM8H2vlb14+Ste5fpZXzOLjd76DDzbWsmqJHoklIsVDgT6KU2d6+eu9R3k+1srrb/ZQXlbCv21YxoNNtfybdy7RdbMiUpQU6GkDg0P85GAHz8dbeeH1E/QPOjfULOCLG9/FhhurWTBH03mKSHGb9oH+y+M9PB9v44e72zl5ppcl88r5yG31PNhUy7XL5wddnohIzqZloHef7+dvXz7K8/E2Xm7toqzE+LU1S3mwsYZfW7NUT0oRkVCaNoE+OOT8rOUk2+Nt7Nr3Jr0DQ1y7bD5/fN91/Ma6apbMmxl0iSIikxL5QE+cPMv2eBs/3N3G0e4LLJg9g3+/vpYHG2tZW12hh9CKSGREMtDP9g7w968eY3usjV8kOikx+NXVVXzuvuu4+7plzJpRGnSJIiJ5F5lAd3d+8UYnz8fb2PnqMc71DbJqyVw+8/5reeDmGpYv0ERAIhJtoQ/0o13n+at4G9t3t5E8dY655aX8uxtW8GBTDY11C9WlIiLTRigD/UL/ILv2vcn2eBs/bTmJO9x69SI++d7V3Hv9cuaUh/JjiYhMSuiSb9svjvClnQfouTBAdeVsPvHe1Xzw5hpWLp4TdGkiIoEKXaCvqJzNXWuW8mBTLbddvZgS3YYvIgKEMNDfc00V79F8yiIil9AtkSIiEaFAFxGJCAW6
iEhEKNBFRCJCgS4iEhEKdBGRiFCgi4hEhAJdRCQizN2D2bFZB5C8wm9fApzMYzlhp+NxMR2Pt+lYXCwKx6PO3Ue9uzKwQJ8MM4u5e1PQdRQLHY+L6Xi8TcfiYlE/HupyERGJCAW6iEhEhDXQnwq6gCKj43ExHY+36VhcLNLHI5R96CIicqmwnqGLiEgWBbqISESELtDN7B4zO2hmLWb2eND1BMXMas3sRTPbb2b7zOyxoGsqBmZWamZ7zOzvgq4laGZWaWbbzex1MztgZrcFXVNQzOz30z8nr5nZD8xsVtA1FUKoAt3MSoGtwL1AA7DZzBqCrSowA8Cn3b0BuBX4vWl8LDI9BhwIuogi8XXgf7v7GuBGpulxMbNq4JNAk7uvBUqBTcFWVRihCnTgFqDF3Q+7ex+wDdgYcE2BcPdj7r47/bqH4R/W6mCrCpaZ1QD3AU8HXUvQzGwB8B7gGwDu3ufuXcFWFagyYLaZlQFzgKMB11MQYQv0aqA1Y7mNaR5iAGZWD6wDfh5sJYH7b8AfAkNBF1IEVgEdwLPpLqinzWxu0EUFwd3bgT8FjgDHgG53/8dgqyqMsAW6ZDGzecBfAZ9y99NB1xMUM7sfOOHu8aBrKRJlwM3A/3L3dcBZYFqOOZnZQob/kl8FrADmmtlvBVtVYYQt0NuB2ozlmvS6acnMZjAc5t9z9x8GXU/Abgc2mFmC4a6495rZd4MtKVBtQJu7j/zVtp3hgJ+O7gbecPcOd+8Hfgi8O+CaCiJsgd4MrDazVWZWzvDAxo6AawqEmRnD/aMH3P1rQdcTNHf/rLvXuHs9w/8vXnD3SJ6F5cLd3wRazeza9Kq7gP0BlhSkI8CtZjYn/XNzFxEdIC4LuoCJcPcBM9sC7GJ4pPoZd98XcFlBuR34MPCqme1Nr/ucu+8MsCYpLp8Avpc++TkMPBxwPYFw95+b2XZgN8NXh+0holMA6NZ/EZGICFuXi4iIXIYCXUQkIhToIiIRoUAXEYkIBbqISEQo0EVEIkKBLiISEf8fxruOmkdFgNkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "train_inputs, train_outputs = generate_data(N, 243)\n",
    "plt.plot(train_outputs[0, :].numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x12da7e410>]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAXmElEQVR4nO3da3Bc933e8e8DkOD9IpKwZJGQCF4UmXYkS0EoySQlj2NnJDtDZSZNhppJWnecKHJDx43Ti+x2NB11+qJJx01fcKzQrjKd8YVVFE+GbeiwncaysdSN0MWySVreJSiRoMRoF7yJpEjcfn2BhbQEF8CC3MXZPft8ZjiDc85/z/lhSTw8+J9zfquIwMzMGl9L0gWYmVl1ONDNzFLCgW5mlhIOdDOzlHCgm5mlxKykDrxixYpYvXp1Uoc3M2tIL730UiEi2sttSyzQV69eTU9PT1KHNzNrSJLenGibp1zMzFLCgW5mlhIOdDOzlHCgm5mlhAPdzCwlHOhmZinhQDczSwkHupnZDBkZCf7T3x3kp31narJ/B7qZ2Qw5+PZZvtl9hF/847s12b8D3cxshmRyBQA2r19Rk/070M3MZkgmW+CXrl/E9Yvn1mT/DnQzsxlwcXCYF984WbOzc3Cgm5nNiBePnGRgaMSBbmbW6DK5Am2tLdzVuaxmx3Cgm5nNgO5sgTtvXsr8ttp1LXegm5nVWP7dSxx6+yxb1pf9XIqqqSjQJd0v6XVJOUmPltn+XyW9WvzzC0mnq1+qmVljevZw8XbFdbWbP4cKPrFIUiuwA/gM0Afsl7Q7Ig6OjYmIPykZ/yXgjhrUambWkLqzBZbMm83HVi6p6XEqOUPfCOQiojciBoBdwIOTjH8I+F41ijMza3QRQSZbYNO65bS2qKbHqiTQVwLHSpb7iuuuIOlmoBP4hwm2PyypR1JPPp+fbq1mZg3ncP4cJ85erPn8OVT/oug24OmIGC63MSJ2RkRXRHS1t9f+mzMzS1p3dmbmz6GyQD8OdJQsryquK2cbnm4xM3tfd7bA6uXz6Vg2v+bHqiTQ9wPrJXVKamM0tHePHyTpVuA64Lnqlmhm1pgGhkZ4vre/pk+Hlpoy0CNiCNgO7AUOAU9FxAFJj0vaWjJ0G7ArIqI2pZqZNZZXjp7iwsAwm9fNzBRzRY8sRcQeYM+4dY+NW/4P1SvLzKzxZXIFWgT3rF0+I8fzk6JmZjXSnS1we8dSlsybPSPHc6CbmdXAmQuDvNZ3mi0zcHfLGAe6mVkNPNdbYCRg8wzcfz7GgW5mVgPd2QIL2lq546alM3ZMB7qZWQ1kcgXuWbuc2a0zF7MOdDOzKjvaf4E3+y/MyNOhpRzoZmZV1p0b7VU1k/Pn4EA3M6u6TLbAh5fMZW37ghk9rgPdzKyKhkeCZw/3s3ndCqTatssdz4FuZlZFPz1+hjPvDc5Y/5ZSDnQzsyrKZEfnzzfN8AVRcKCbmVVVd7bAhg8vZsXCOTN+bAe6mVmVnL80xMtHT7Hllpk/OwcHuplZ1bx45CSDw8GWGWqXO54D3cysSrqzBebMaqFr9XWJHN+BbmZWJd3ZPBs7lzF3dmsix3egm5lVwYkzF8m+c27GH/cv5UA3M6uCTK4AkMj952Mc6GZmVZDJ5lm+oI2P3LA4sRoc6GZm1ygiyOT62bRuBS0tM/u4fykHupnZNfr5iXcpnLuU6HQLVBjoku6X9LqknKRHJxjzO5IOSjog6bvVLdPMrH5lsqPz51sSDvRZUw2Q1ArsAD4D9AH7Je2OiIMlY9YDXwU2RcQpSR+qVcFmZvWmO1dg3YcW8uEl8xKto5Iz9I1ALiJ6I2IA2AU8OG7MHwA7IuIUQES8U90yzczq08XBYV480p/o7YpjKgn0lcCxkuW+4rpStwC3SNon6XlJ95fbkaSHJfVI6snn81dXsZlZHXnpzVNcHBxJfLoFqndRdBawHvgk8BDw
TUlXfNR1ROyMiK6I6GpvT6bXgZlZNXVnC8xqEXetWZ50KRUF+nGgo2R5VXFdqT5gd0QMRsQR4BeMBryZWaplcnnuvOk6Fs6Z8pJkzVUS6PuB9ZI6JbUB24Dd48b8LaNn50hawegUTG8V6zQzqzsnzw9w4K2zid+uOGbKQI+IIWA7sBc4BDwVEQckPS5pa3HYXqBf0kHgh8C/joj+WhVtZlYP9uUKRCT7uH+pin5HiIg9wJ5x6x4r+TqArxT/mJk1hUy2wKK5s7ht5ZKkSwH8pKiZ2VUZfdy/wKa1K5jVWh9RWh9VmJk1mCOF8xw//V7dTLeAA93M7KqMtcuth/vPxzjQzcyuQne2QMeyedy8fEHSpbzPgW5mNk2DwyM8d7ifzQl9GPREHOhmZtP0k2OnOXdpqK6mW8CBbmY2bd3ZAhJ8Ym3yj/uXcqCbmU1TJlfgtpVLWDq/LelSLuNANzObhrMXB3n12Om6ul1xjAPdzGwanj/cz/BIsGV9fV0QBQe6mdm0ZHIF5re1cudN1yVdyhUc6GZm05DJFrircxlts+ovPuuvIjOzOtV36gK9hfNsrsPpFnCgm5lVLJOtv8f9SznQzcwq1J0rcP3iOaz/0MKkSynLgW5mVoGRkeDZXIFN61YgKelyynKgm5lV4MBbZzl1YbBup1vAgW5mVpHuXB6ATesc6GZmDS2TLXDrDYv40KK5SZcyIQe6mdkU3hsYpueNU3U93QIOdDOzKb34xkkGhkfq9v7zMRUFuqT7Jb0uKSfp0TLbPy8pL+nV4p/fr36pZmbJyGTztLW2sHH1sqRLmdSsqQZIagV2AJ8B+oD9knZHxMFxQ/9nRGyvQY1mZonqzhboWn0d89paky5lUpWcoW8EchHRGxEDwC7gwdqWZWZWH9559yI/P/FuXbbLHa+SQF8JHCtZ7iuuG++3JL0m6WlJHVWpzswsYftyxcf96+zzQ8up1kXR/wWsjojbgP8L/I9ygyQ9LKlHUk8+n6/Soc3Maqc7W+C6+bP56I2Lky5lSpUE+nGg9Ix7VXHd+yKiPyIuFRe/BfxKuR1FxM6I6IqIrvb2+v/fzsyaW0SQyRb4xLoVtLTU5+P+pSoJ9P3AekmdktqAbcDu0gGSPlyyuBU4VL0SzcySkX3nHO+8e4l7G2D+HCq4yyUihiRtB/YCrcCTEXFA0uNAT0TsBv5Y0lZgCDgJfL6GNZuZzYjuYrvcer//fMyUgQ4QEXuAPePWPVby9VeBr1a3NDOzZGWyedasWMDKpfOSLqUiflLUzKyMgaERXjhysiFuVxzjQDczK+Plo6e4MDDM5jrurjieA93MrIzubJ7WFnH32uVJl1IxB7qZWRmZbIGPdyxl8dzZSZdSMQe6mdk4py8M8NrxMw013QIOdDOzKzx7uJ8I6r7/+XgOdDOzcbqzBRbNmcXtHUuTLmVaHOhmZuNkcnnuXruc2a2NFZGNVa2ZWY292X+eYyffa7jpFnCgm5ld5v3H/Rvsgig40M3MLpPJFli5dB6dKxYkXcq0OdDNzIqGhkfYd7jA5nUrkOq/Xe54DnQzs6LXjp/h3YtDDdW/pZQD3cysKJMtIMGmBpw/Bwe6mdn7MtkCH71xMcsWtCVdylVxoJuZAecuDfHy0VNsaZAPsyjHgW5mBrzQ28/QSLClQadbwIFuZgaM3n8+d3YLv7L6uqRLuWoOdDMzIJMrsLFzOXNmtSZdylVzoJtZ03v7zHvk3jnX0NMt4EA3MyMz9rh/g95/PsaBbmZNrztbYMXCOdx6w6KkS7kmFQW6pPslvS4pJ+nRScb9lqSQ1FW9Es3MamdkJNiXK7B53fKGfNy/1JSBLqkV2AE8AGwAHpK0ocy4RcCXgReqXaSZWa0cOnGW/vMDbG7g+8/HVHKGvhHIRURvRAwAu4AHy4z7j8B/Bi5WsT4zs5oamz9vxP7n41US6CuBYyXLfcV175N0J9AREX832Y4kPSypR1JPPp+fdrFmZtWWyRW45fqF
XL94btKlXLNrvigqqQX4OvCnU42NiJ0R0RURXe3tjf/rjZk1touDw7x45CSb16UjjyoJ9ONAR8nyquK6MYuAjwHPSHoDuBvY7QujZlbvet44xaWhkVRMt0Blgb4fWC+pU1IbsA3YPbYxIs5ExIqIWB0Rq4Hnga0R0VOTis3MqqQ7l2d2q7hrzbKkS6mKKQM9IoaA7cBe4BDwVEQckPS4pK21LtDMrFYy2QJ33nQd89tmJV1KVVT0XUTEHmDPuHWPTTD2k9delplZbRXOXeLAW2f5V79+S9KlVI2fFDWzprQvN/a4fzouiIID3cyaVCZbYMm82fzyyiVJl1I1DnQzazoRQSZXYNO65bS2NPbj/qUc6GbWdA7nz/P2mYupuf98jAPdzJpOJjv6pHpa7j8f40A3s6aTyRW4efl8OpbNT7qUqnKgm1lTGRwe4fnek2xu8E8nKseBbmZN5dVjpzl3aSh10y3gQDezJtOdLdAiuGetA93MrKF1Z/PctmopS+bNTrqUqnOgm1nTOPPeID85djqV0y3gQDezJvLc4X5GArak6HH/Ug50M2samVyeBW2t3HHT0qRLqQkHupk1jUy2wN1rljO7NZ3Rl87vysxsnGMnL/BG/wU2p3T+HBzoZtYkMsV2uWm9IAoVfsCFXenk+QFeOXqq7LaIiV83ySZikhdO9rrxSnvHSZpkW/mvR8eVbiz/+nL7vxaTff/XolyNV34f47driu3lDjT5Piba14T7Y4LaJxh8NX8TlbzjU/211OLvbTr/riod+oOfneCGxXNZ277wKquqfw70q/Sl773Mvlx/0mWY2TRs+9WOqp6E1BsH+lX4ybHT7Mv188VPruWBj90w4bjpnKFVopLXTfrbQcm2KDk3G/+a0sXSs6/xu778dcHVnSNerto/a+Xfj5h0zOTfZ/kz0qle88G4MhsmHFtm3XT2Wxw/1Xs60b/Ty8ZMuY/qmc75/nR/Obi9Iz0fZlGOA/0qPPGjwyyeO4t/8cm1LJqbvqfNzKwx+aLoNB3On+PvD5zg9+652WFuZnWlokCXdL+k1yXlJD1aZvsjkn4q6VVJGUkbql9qffjmj3tpa23h85/oTLoUM7PLTBnoklqBHcADwAbgoTKB/d2I+OWI+DjwZ8DXq15pHfjHsxf5/svH+e2uVbQvmpN0OWZml6nkDH0jkIuI3ogYAHYBD5YOiIizJYsLmN51jYbxZOYIQyMjPLxlbdKlmJldoZKLoiuBYyXLfcBd4wdJ+iPgK0Ab8KlyO5L0MPAwwE033TTdWhN15r1BvvPCUT53243ctDxdH1tlZulQtYuiEbEjItYC/xb49xOM2RkRXRHR1d7eWN3Ovv38m5y7NMQj961JuhQzs7IqCfTjQEfJ8qriuonsAn7zWoqqNxcHh/mrfUe495Z2Pnpjuu9jNbPGVUmg7wfWS+qU1AZsA3aXDpC0vmTxc0C2eiUm7+mX+iicG+CL93nu3Mzq15Rz6BExJGk7sBdoBZ6MiAOSHgd6ImI3sF3Sp4FB4BTwz2pZ9EwaGh5h5497ub1jKXevWZZ0OWZmE6roSdGI2APsGbfusZKvv1zluurGD352gqMnL/C1z34k1T0gzKzx+UnRSUQET/zoMGvaF/DrG65Puhwzs0k50CfRnS1w4K2z/OG9a2hp8dm5mdU3B/oknvjRYa5fPIffvGNl0qWYmU3JgT6Bnxw7zbOH+/nC5k7mzGpNuhwzsyk50Ccw1iL3oY2N9USrmTUvB3oZbpFrZo3IgV6GW+SaWSNyoI/jFrlm1qgc6OO4Ra6ZNSoHegm3yDWzRuZAL+EWuWbWyBzoRW6Ra2aNzoFe5Ba5ZtboHOi4Ra6ZpYMDnQ9a5H7xvrVukWtmDavpAz0i+MYzbpFrZo2v6QO9O1vg4NtneeTetW6Ra2YNrekD/RvPjLbIffCOG5MuxczsmjR1oL967DTP9fbz+5vXuEWumTW8pg70J54ptsi9yy1yzazxNW2gH86fY+/B
E/zTe1azcE5Fn5VtZlbXKgp0SfdLel1STtKjZbZ/RdJBSa9J+n+Sbq5+qdW180fFFrmbViddiplZVUwZ6JJagR3AA8AG4CFJG8YNewXoiojbgKeBP6t2odV04sxFvv9KH7/T1cGKhW6Ra2bpUMkZ+kYgFxG9ETEA7AIeLB0QET+MiAvFxeeBVdUts7qe3HeE4ZHgD7a4CZeZpUclgb4SOFay3FdcN5EvAD8ot0HSw5J6JPXk8/nKq6yiM+8N8t0XjvIbbpFrZilT1Yuikn4X6AL+vNz2iNgZEV0R0dXe3l7NQ1dsrEXuH7pFrpmlTCW3dxwHOkqWVxXXXUbSp4F/B9wXEZeqU151uUWumaVZJWfo+4H1kjoltQHbgN2lAyTdAfwlsDUi3ql+mdXhFrlmlmZTBnpEDAHbgb3AIeCpiDgg6XFJW4vD/hxYCPy1pFcl7Z5gd4lxi1wzS7uKnqiJiD3AnnHrHiv5+tNVrqvqxlrkfu2zH3GLXDNLpaZ4UtQtcs2sGTRFoLtFrpk1g6YIdLfINbNmkPpAd4tcM2sWqQ90t8g1s2aR6kB3i1wzayapDnS3yDWzZpLaQHeLXDNrNqkNdLfINbNmk8pAP3NhkO88/6Zb5JpZU0lloH/7hTc5PzDsFrlm1lRSF+hjLXLvc4tcM2syqQv0vy62yH3ELXLNrMmkKtCHhkf45o97+bhb5JpZE0pVoO8ptsh95L61bpFrZk0nNYEeETzhFrlm1sRSE+g/dotcM2tyqQn0J545zA2L57pFrpk1rVQE+liL3C9s7nSLXDNrWqkIdLfINTNLQaC7Ra6Z2aiKAl3S/ZJel5ST9GiZ7fdKelnSkKR/Uv0yJ+YWuWZmo6YMdEmtwA7gAWAD8JCkDeOGHQU+D3y32gVOxi1yzcw+UMkcxUYgFxG9AJJ2AQ8CB8cGRMQbxW0jNahxQm6Ra2b2gUqmXFYCx0qW+4rrpk3Sw5J6JPXk8/mr2cX73CLXzOxyM3pRNCJ2RkRXRHS1t7df077cItfM7HKVBPpxoKNkeVVxXWLcItfM7EqVBPp+YL2kTkltwDZgd23Lmpxb5JqZXWnKQI+IIWA7sBc4BDwVEQckPS5pK4CkX5XUB/w28JeSDtSqYLfINTMrr6IncSJiD7Bn3LrHSr7ez+hUTM2Ntcj92mc/4ha5ZmYlGu5J0YVzWvnMhuvdItfMbJyGe1b+U7dez6dudZibmY3XcGfoZmZWngPdzCwlHOhmZinhQDczSwkHuplZSjjQzcxSwoFuZpYSDnQzs5RQRCRzYCkPvHmVL18BFKpYTqPz+3E5vx8f8HtxuTS8HzdHRNn+44kF+rWQ1BMRXUnXUS/8flzO78cH/F5cLu3vh6dczMxSwoFuZpYSjRroO5MuoM74/bic348P+L24XKrfj4acQzczsys16hm6mZmN40A3M0uJhgt0SfdLel1STtKjSdeTFEkdkn4o6aCkA5K+nHRN9UBSq6RXJP3vpGtJmqSlkp6W9HNJhyTdk3RNSZH0J8Wfk59J+p6kuUnXVAsNFeiSWoEdwAPABuAhSRuSrSoxQ8CfRsQG4G7gj5r4vSj1ZUY/zNzgvwF/HxG3ArfTpO+LpJXAHwNdEfExoBXYlmxVtdFQgQ5sBHIR0RsRA8Au4MGEa0pERLwdES8Xv36X0R/WlclWlSxJq4DPAd9KupakSVoC3Av8d4CIGIiI08lWlahZwDxJs4D5wFsJ11MTjRboK4FjJct9NHmIAUhaDdwBvJBsJYn7C+DfACNJF1IHOoE88FfFKahvSVqQdFFJiIjjwH8BjgJvA2ci4v8kW1VtNFqg2ziSFgJ/A/zLiDibdD1JkfQbwDsR8VLStdSJWcCdwDci4g7gPNCU15wkXcfob/KdwI3AAkm/m2xVtdFogX4c6ChZXlVc15QkzWY0zL8TEd9Pup6EbQK2SnqD0am4T0n6drIlJaoP6IuIsd/anmY04JvRp4EjEZGPiEHg+8AnEq6pJhot
0PcD6yV1Smpj9MLG7oRrSoQkMTo/eigivp50PUmLiK9GxKqIWM3ov4t/iIhUnoVVIiJOAMck/VJx1a8BBxMsKUlHgbslzS/+3PwaKb1APCvpAqYjIoYkbQf2Mnql+smIOJBwWUnZBPwe8FNJrxbXfS0i9iRYk9WXLwHfKZ789AL/POF6EhERL0h6GniZ0bvDXiGlLQD86L+ZWUo02pSLmZlNwIFuZpYSDnQzs5RwoJuZpYQD3cwsJRzoZmYp4UA3M0uJ/w8eklcAmw1+CgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "val_inputs, val_outputs = generate_data(N_val, 0)\n",
    "plt.plot(val_outputs[0, :].numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Monomial fit to each component\n",
    "\n",
    "We will initialize the parameters in our LLCP model by fitting monomials to the training data, without enforcing the monotonicity constraint."
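    , "\n",
    "\n",
    "Taking logarithms, the monomial model $y_i \\approx c_i x_1^{A_{i1}} \\cdots x_n^{A_{in}}$ becomes the linear model $\\log y_i \\approx \\log c_i + \\sum_{j=1}^n A_{ij} \\log x_j$, so the initial parameters solve the least squares problem\n",
    "$$\n",
    "\\mbox{minimize} \\quad \\frac{1}{N} \\sum_{(x, y) \\in \\mathcal D} \\|\\log c + A \\log x - \\log y\\|_2^2\n",
    "$$\n",
    "over $\\log c \\in \\mathbf{R}^m$ and $A \\in \\mathbf{R}^{m \\times n}$, with $\\log$ applied elementwise. In the code, `log_c` holds $\\log c$ and `theta` holds $A^T$."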
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_c = cp.Variable(shape=(m,1))\n",
    "theta = cp.Variable(shape=(n, m))\n",
    "inputs_np = train_inputs.numpy()\n",
    "log_outputs_np = np.log(train_outputs.numpy()).T\n",
    "log_inputs_np = np.log(inputs_np).T\n",
    "offsets = cp.hstack([log_c]*N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "cp_preds = theta.T @ log_inputs_np + offsets\n",
    "objective_fn = (1/N) * cp.sum_squares(cp_preds - log_outputs_np)\n",
    "lstq_problem = cp.Problem(cp.Minimize(objective_fn))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lstq_problem.is_dcp()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------------------------------------------------------\n",
      "           OSQP v0.6.0  -  Operator Splitting QP Solver\n",
      "              (c) Bartolomeo Stellato,  Goran Banjac\n",
      "        University of Oxford  -  Stanford University 2019\n",
      "-----------------------------------------------------------------\n",
      "problem:  variables n = 1210, constraints m = 1000\n",
      "          nnz(P) + nnz(A) = 23000\n",
      "settings: linear system solver = qdldl,\n",
      "          eps_abs = 1.0e-05, eps_rel = 1.0e-05,\n",
      "          eps_prim_inf = 1.0e-04, eps_dual_inf = 1.0e-04,\n",
      "          rho = 1.00e-01 (adaptive),\n",
      "          sigma = 1.00e-06, alpha = 1.60, max_iter = 10000\n",
      "          check_termination: on (interval 25),\n",
      "          scaling: on, scaled_termination: off\n",
      "          warm start: on, polish: on, time_limit: off\n",
      "\n",
      "iter   objective    pri res    dua res    rho        time\n",
      "   1   0.0000e+00   3.30e+00   1.22e+04   1.00e-01   3.06e-03s\n",
      "  50   1.0014e-02   1.72e-07   1.64e-07   1.75e-03   7.37e-03s\n",
      "plsh   1.0014e-02   1.56e-15   1.17e-14   --------   9.68e-03s\n",
      "\n",
      "status:               solved\n",
      "solution polish:      successful\n",
      "number of iterations: 50\n",
      "optimal objective:    0.0100\n",
      "run time:             9.68e-03s\n",
      "optimal rho estimate: 8.77e-05\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.010014212812318733"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lstq_problem.solve(verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "c = torch.exp(torch.tensor(log_c.value)).squeeze()\n",
    "lstsq_val_preds = []\n",
    "for i in range(N_val):\n",
    "    inp = val_inputs[i, :].numpy()\n",
    "    # Monomial prediction c_i * prod_j inp_j^{A_ij}, with A = theta^T\n",
    "    pred = cp.multiply(c.numpy(), cp.gmatmul(theta.T.value, inp))\n",
    "    lstsq_val_preds.append(pred.value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "A_param = cp.Parameter(shape=(m, n))\n",
    "c_param = cp.Parameter(pos=True, shape=(m,))\n",
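    "# Auxiliary variable constrained to equal x_param, so that gmatmul(A_param, .) acts on a\n",
    "# variable rather than a parameter; this keeps the problem DPP\n",
    "x_slack = cp.Variable(pos=True, shape=(n,))\n",
    "x_param = cp.Parameter(pos=True, shape=(n,))\n",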
    "y = cp.Variable(pos=True, shape=(m,))\n",
    "\n",
    "prediction = cp.multiply(c_param, cp.gmatmul(A_param, x_slack))\n",
    "objective_fn = cp.sum(prediction / y + y / prediction)\n",
    "constraints = [x_slack == x_param]\n",
    "for i in range(m-1):\n",
    "    constraints += [y[i] <= y[i+1]]\n",
    "problem = cp.Problem(cp.Minimize(objective_fn), constraints)\n",
    "problem.is_dgp(dpp=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arbitrary initial parameter values; the layer passes its own values on each call\n",
    "A_param.value = np.random.randn(m, n)\n",
    "x_param.value = np.abs(np.random.randn(n))\n",
    "c_param.value = np.abs(np.random.randn(m))\n",
    "\n",
    "layer = CvxpyLayer(problem, parameters=[A_param, c_param, x_param], variables=[y], gp=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(epoch 0) train / val (0.0018 / 0.0014) \n",
      "(epoch 1) train / val (0.0017 / 0.0014) \n",
      "(epoch 2) train / val (0.0017 / 0.0014) \n",
      "(epoch 3) train / val (0.0017 / 0.0014) \n",
      "(epoch 4) train / val (0.0017 / 0.0014) \n",
      "(epoch 5) train / val (0.0017 / 0.0014) \n",
      "(epoch 6) train / val (0.0016 / 0.0014) \n",
      "(epoch 7) train / val (0.0016 / 0.0014) \n",
      "(epoch 8) train / val (0.0016 / 0.0014) \n",
      "(epoch 9) train / val (0.0016 / 0.0014) \n"
     ]
    }
   ],
   "source": [
    "torch.random.manual_seed(1)\n",
    "A_tch = torch.tensor(theta.T.value)\n",
    "A_tch.requires_grad_(True)\n",
    "c_tch = torch.tensor(np.squeeze(np.exp(log_c.value)))\n",
    "c_tch.requires_grad_(True)\n",
    "train_losses = []\n",
    "val_losses = []\n",
    "\n",
    "lam1 = torch.tensor(1e-1)\n",
    "lam2 = torch.tensor(1e-1)\n",
    "\n",
    "opt = torch.optim.SGD([A_tch, c_tch], lr=5e-2)\n",
    "for epoch in range(10):\n",
    "    preds = layer(A_tch, c_tch, train_inputs, solver_args={'acceleration_lookback': 0})[0]\n",
    "    loss = (preds - train_outputs).pow(2).sum(axis=1).mean(axis=0)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        val_preds = layer(A_tch, c_tch, val_inputs, solver_args={'acceleration_lookback': 0})[0]\n",
    "        val_loss = (val_preds - val_outputs).pow(2).sum(axis=1).mean(axis=0)\n",
    "\n",
    "    print('(epoch {0}) train / val ({1:.4f} / {2:.4f}) '.format(epoch, loss, val_loss))\n",
    "    train_losses.append(loss.item())\n",
    "    val_losses.append(val_loss.item())\n",
    "    \n",
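    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "    # Projection step: clamp c in place so it stays strictly positive and\n",
    "    # remains the same leaf tensor that the optimizer updates\n",
    "    with torch.no_grad():\n",
    "        c_tch.clamp_(min=1e-8)"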
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    train_preds_tch = layer(A_tch, c_tch, train_inputs)[0]\n",
    "    train_preds = [t.detach().numpy() for t in train_preds_tch]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    val_preds_tch = layer(A_tch, c_tch, val_inputs)[0]\n",
    "    val_preds = [t.detach().numpy() for t in val_preds_tch]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAEICAYAAABfz4NwAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3de3xU9Z3/8dcn94QA4RIgJEBCBDEgBAmgoBRFEMRFq7XFXrZ0t1Lbetm2trXdrlpqW7u6vVl3LW1d+9tqrVJrqaIoeBdBAnIPlxACSbgFJIGQe+b7+2MGCCEkATJzJsn7+Xjk8ZhzzvfM953JZD7zPVdzziEiIhLhdQAREQkPKggiIgKoIIiISIAKgoiIACoIIiISEOV1gPPVt29fl56e7nUMEZEOZc2aNYecc8nNLeuwBSE9PZ3c3FyvY4iIdChmtvtsy7TJSEREABUEEREJUEEQERGgA+9DaE5dXR3FxcVUV1d7HUUaiYuLIy0tjejoaK+jiEgLOlVBKC4upnv37qSnp2NmXscRwDnH4cOHKS4uJiMjw+s4ItKCTrXJqLq6mj59+qgYhBEzo0+fPhq1iXQAnaogACoGYUh/E5GOodMVBBEROT8qCO0sMTHxjHkPPvggjz766Bnz9+/fz9y5c8nMzGTcuHFcf/31bN++ncLCQuLj48nOziYrK4s77rgDn88XivgiEsYOHDjAwoUL2bt3b1CeXwXBI845PvnJTzJ16lR27tzJmjVr+OlPf8qBAwcAyMzMZN26dWzYsIEtW7bw4osvepxYRLwWFxdHRY8elAXpC2KnOsqoI3nzzTeJjo7mjjvuODlvzJgxABQWFp6cFxUVxaRJk8jPzw91RBEJMwfq6/nWtm0kXHQRWWlp7f78nbYg/Nurr7Ju//52fc7sAQP45cyZ7fJcmzZtYty4ca22q6ysZPny5SxYsKBd+hWRjsk5x8srVxILTAvSIdydtiB0dDt37iQ7Oxsz48Ybb2TWrFleRxIRD5WWlnI0N5cr4+O5qHfvoPTRaQtCe32TD5aRI0eyaNGisy4/sQ9BRARg165dAAwN4om32qnskWuuuYaamhoWLlx4ct6GDRt49913PUwlIuFq/bZtlANTLr44aH2oILSzyspK0tLSTv78/Oc/B+Chhx46bb6Z8be//Y1ly5aRmZnJyJEj+d73vseAAQM8/g1EJNw459hXXMxu4NrMzKD102k3GXnlbOcLPPjgg2fMGzhwIM8991yz7Tdt2tSesUSkAzt8+DDU1TGmz24GvDcVrlwEient3o9GCCIiYa5bz548GRnJtP774Og2iB8YlH40QhARCXMri4vZ01DPWNsI/adBZExQ+tEIQUQkjDnneHP5cq6OPkxCbQmkBO8IShUEEZEwduTIESJLSvhC3xL/jJTrgtaXNhmJiISxLTt2ADA5aQ/ED4fE4N1oSiMEEZEwtjYvj2qrI9O3OaibiyBEBcHMZprZNjPLN7P7mln+CzNbF/jZbmZlocgVDM1d/vpCPPXUU0G71K2IhDfnHIf27aN7/B4ifdUwsIMXBDOLBB4HZgFZwG1mltW4jXPuG865bOdcNvAY8EKwc3UU4VIQ6uvrvY4g0uVUVVVRXV/PZX2KISIW+n0iqP2FYoQwAch3zhU452qBZ4EbW2h/G/DnEOQKukceeYTx48czevRoHnjggZPzb7rpJsaNG8fIkSNPXrqioaGBefPmMWrUKC699FJ+8YtfsGjRInJzc/nc5z5HdnY2VVVVpz3/r3/9a7Kyshg9ejRz584F/CewzJgxg5EjR/LlL3+ZIUOGcOjQIQoLCxk1atTJdR999NGTJ8v97ne/Y/z48YwZM4ZbbrmFyspKAObNm8cdd9zBxIkT+c53vsPOnTuZOXMm48aN46qrrmLr1q0APP/884waNYoxY8YwZcqUoL2eIl1NWUMDj/h8HB10A4z5MUQlBLW/
UOxUTgWKGk0XAxOba2hmQ4AM4I2zLJ8PzAcYPHhwqx0/9dRTZ8wbOXIk48ePp66ujqeffvqM5dnZ2WRnZ1NZWXnGWcTz5s1rtc8TXnvtNXbs2MGHH36Ic445c+bwzjvvMGXKFJ588kl69+5NVVUV48eP55ZbbqGwsJCSkpKTZyiXlZWRlJTEb37zGx599FFycnLO6OPhhx9m165dxMbGUlbm38r2wx/+kCuvvJL777+fl19+mT/84Q+tZr355pu5/fbbAfjBD37AH/7wB+666y4AiouLWbFiBZGRkUybNo0nnniCYcOGsWrVKr72ta/xxhtvsGDBApYuXUpqaurJHCJy4ZYXFABw0ajbICUl6P2F21FGc4FFzrmG5hY65xYCCwFycnJcKIOdq9dee43XXnuNsWPHAlBRUcGOHTuYMmUKv/71r/nb3/4GQFFRETt27ODiiy+moKCAu+66i9mzZzNjxoxW+xg9ejSf+9znuOmmm7jpppsAeOedd3jhBf8Wt9mzZ9OrV69Wn2fTpk384Ac/oKysjIqKCq677tRhbbfeeiuRkZFUVFSwYsUKbr311pPLampqAJg8eTLz5s3j05/+NDfffHMbXyERac2W117jywmHyI4rBTcAgnSV0xNCURBKgEGNptMC85ozF/h6e3Xc0jf66OjoFpcnJCSc04igKecc3/ve9/jKV75y2vy33nqLZcuW8cEHH5CQkMDUqVOprq6mV69erF+/nqVLl/LEE0/w3HPP8eSTT7bYx8svv8w777zDP/7xD3784x+zcePGs7aNioo67TpL1dXVJx/PmzePF198kTFjxvDUU0/x1ltvnVzWrVs3wH+NpqSkpGYvyf3EE0+watUqXn75ZcaNG8eaNWvo06dPi9lFpGVlZWXEVVby1fT3iHj773Dj7qD3GYp9CKuBYWaWYWYx+D/0FzdtZGYjgF7AByHIFHTXXXcdTz75JBUVFQCUlJRw8OBBysvL6dWrFwkJCWzdupWVK1cCcOjQIXw+H7fccgsPPfQQa9euBaB79+4cO3bsjOf3+XwUFRVx9dVX87Of/Yzy8nIqKiqYMmUKzzzzDACvvPIKR44cAaB///4cPHiQw4cPU1NTw0svvXTyuY4dO0ZKSspZN6MB9OjRg4yMDJ5//nnAX/DWr18P+G/mM3HiRBYsWEBycjJFRUXNPoeItN3KzZsBWD/wPpj0TNBHBxCCEYJzrt7M7gSWApHAk865zWa2AMh1zp0oDnOBZ51zYb0pqK1mzJhBXl4eV1xxBeA/HPVPf/oTM2fO5IknnuCSSy7h4osv5vLLLwf8BeNLX/rSyW/xP/3pT4FTO3bj4+P54IMPiI+PB/w7oT//+c9TXl6Oc467776bpKQkHnjgAW677TZGjhzJpEmTTu5riY6O5v7772fChAmkpqYyYsSIk1l/9KMfMXHiRJKTk5k4cWKzBQjg6aef5qtf/SoPPfQQdXV1zJ07lzFjxvDtb3+bHTt24Jxj2rRpJ+8NLSLnb/3WrVQBV152LYRoxG0d9fM3JyfH5ebmnjYvLy+PSy65xKNE4Sk9PZ3c3Fz69u3raQ79bUTOzfd/8hOGJKxj/g3TsYtub7fnNbM1zrkzj1JBZyqLiISduvp68nw+ZvX9ENv1fyHrN9yOMpJ2VlhY6HUEETlH6w8eZAXlDI7YDQPnh6xfjRBERMLM61u2MDNhp38iyNcvakwFQUQkzJR9+CFf7b4L4vpBr+yQ9auCICISRg4dOUJCXQ2jEvJhwHVgofuYVkEQEQkjyz/6iJTYfSRaRdCvbtqUCkI7Kisr47//+7+9jiEiHdjmHTvISMjHYTBgekj7VkFoR2crCLp0tIi0VeWhQwztXoD1zoG45JD2rYLQju677z527txJdnY248eP56qrrmLOnDlkZWW1ePnps11WWkS6lrLqav5eX0tDfB9Ia+kuAcHRuc9DWDa19TapN8Al955qP3Se/6f6ELz3qdPbXvtWi0/18MMP
s2nTJtatW8dbb73F7Nmz2bRpExkZGS2eDzB//vxmLystIl3L24WF5BPBvsv/yvD09JD337kLgscmTJhARkbLN8Ru6bLSItK1vL12LZdG+bg8Lc2T/jt3QWjlG32L7eP6nvv6TZy4dDSc/fLTLV1WWkS6Flewk7cGP0bsplrI/mnI+9c+hHZ0tktVw9kvP93SZaVFpOvYeeAAfXy1bIqYCn0neZKhc48QQqxPnz5MnjyZUaNGER8fT//+/U8ua+ny02e7rLSIdB2vr11Lg4uiduR/QNpYTzKoILSzEzenac7dd9/N3Xfffcb8jIwMXn311WDGEpEwt72ggIvjipgyouX9jsGkTUYiIh5zztFQdoB/TXuKmK0/9iyHCoKIiMe2Hz5MQcxOoqwBUq7zLEenKwgd9Q5wnZn+JiItW1ZQwIxu+fgiEyD5Ss9ydKqCEBcXx+HDh/UBFEaccxw+fJi4uDivo4iErfVr1vCpbjux/ldDpHf/K51qp3JaWhrFxcWUlpZ6HUUaiYuLI82jE21Ewl2Dz8eAsq2k9Dwc0pvhNCckBcHMZgK/AiKB3zvnHm6mzaeBBwEHrHfOffZc+4mOjm71zGARkXCyavduJsbn+ydCfLnrpoJeEMwsEngcmA4UA6vNbLFzbkujNsOA7wGTnXNHzKxfsHOJiISDN9ev51MJ+VTHDCKu+0WeZgnFPoQJQL5zrsA5Vws8CzS9jN/twOPOuSMAzrmDIcglIuK5osJ8MhJ2ET1ottdRQlIQUoGiRtPFgXmNDQeGm9n7ZrYysInpDGY238xyzSxX+wlEpKOrrq8nrnYrMRF1RKZe73WcsDnKKAoYBkwFbgN+Z2ZJTRs55xY653KccznJyaG9cYSISHtbUVTEr6oGsfzSN2DAtV7HCUlBKAEGNZpOC8xrrBhY7Jyrc87tArbjLxAiIp3Wsvx8Is2YMHwSRMV7HSckBWE1MMzMMswsBpgLLG7S5kX8owPMrC/+TUgFIcgmIuKZY5veZW3KcyRUbGm9cQgEvSA45+qBO4GlQB7wnHNus5ktMLM5gWZLgcNmtgV4E/i2c+5wsLOJiHilrLqagbUlDIkrJjIyPE4JC0kK59wSYEmTefc3euyAbwZ+REQ6vTfy86mqHswL8b/nSz1Htb5CCITLTmURkS7lvY0bicQx8dJLwczrOIAKgoiIJ+IPvMc3M/6Li3qEzyH04bHhSkSkCyk5epT+kZvpFnmciN5ZXsc5SSMEEZEQW75rFxO75VPVcxzE9vY6zkkqCCIiIZa7dTXjY/cSP/gGr6OcRgVBRCSEnHMM3r+MCHPU9b3G6zinUUEQEQmhbYcPc3nsVip93YgdMMnrOKdRQRARCaFl27cxJiGf0uixEBHpdZzTqCCIiITQvu3L6B51nOgwuLppUyoIIiIh0uDzMbTifQCSsj7tcZoz6TwEEZEQWbtvH/9zLJOs9Iu5onem13HOoIIgIhIiywoKWFM7gMyr7/U6SrNUEEREQmTf9uU81LuERNfgdZRmaR+CiEgIVNXVcU3tEu5N+n+4MC0IGiGIiITAiqIiFpdOp6z+KuZ17+V1nGapIIiIhMDyggIGuBiShkz0OspZqSCIiIRAfMGfmdI7nz4ZN3od5axUEEREgqysupqr7R1GdC8jLuMir+OclXYqi4gE2TsFWxkXX0j14BtITEz0Os5ZqSCIiARZ0Y5/EB9RT//hn/I6SotCUhDMbKaZbTOzfDO7r5nl88ys1MzWBX6+HIpcIiKh0Kd0OXW+KA5xiddRWhT0fQhmFgk8DkwHioHVZrbYObelSdO/OOfuDHYeEZFQKjl6lMujNrO7agj9evT1Ok6LQjFCmADkO+cKnHO1wLNA+O5mFxFpRyu3vU96zCGKfKPo0aOH13FaFIqCkAoUNZouDsxr6hYz22Bmi8xsUHNPZGbzzSzXzHJLS0uDkVVEpF0dLVwMgK/ftR4naV247FT+B5DunBsNvA78sblGzrmFzrkc51xO
cnJySAOKiJwr5xyDjr5HWV1P+mZc6XWcVoWiIJQAjb/xpwXmneScO+ycqwlM/h4YF4JcIiJBte3wYQqru7HFdwXpGRlex2lVKArCamCYmWWYWQwwF1jcuIGZpTSanAPkhSCXiEhQLSso4PZDsxhw0zP07NnT6zitCvpRRs65ejO7E1gKRAJPOuc2m9kCINc5txi428zmAPXAx8C8YOcSEQm2D3eu55Ie3RnaKzwvZteUOee8znBecnJyXG5urtcxRESa1eDzkffHNCLqknCTn2fkyJFeRwLAzNY453KaW6ZrGYmIBMGavXtZWjae4Q3xzBg40Os4baKCICISBMt37WLj0cuI796dpKQkr+O0Sbgcdioi0qkc3vUSl8UcJTMjAzPzOk6baIQgItLOqmqr+b79hqJeF9Ew5HNex2kzjRBERNrZxq2v0juykvKUmVx0Ufje/6ApjRBERNpZ+a4X8TnjsmvuJrF7eF+/qDGNEERE2lm/o+9R0DCIBl/43gynOSoIIiLtqKx8L6MiCthbns6OHTu8jnNOVBBERNrRts3PEWmO/MqLSE9P9zrOOVFBEBFpRw17X+F4QxxHIofTp08fr+OcExUEEZH24hxDq1ZRUJ3J4CFDO8z5ByeoIIiItJP9BzbRyyrYVZHJkCFDvI5zzlQQRETayWsHfPQu+C4D/+knjB492us450znIYiItJNlBQUkxCdxWWYWER1scxFohCAi0i5cXQVfK7+XH3UrYMf27V7HOS9tHiEEbmgTBawD1jnnOuZvLCISBDtLNhPZUENCRSVlZWVexzkvbR4hOOfuB34FlAOfNLPfBS2ViEgH8+pB48vFt7OramiH3KEMbRghmNnrwL3OufXOuQP4b4W5NOjJREQ6kHcLtnFpTAxxERH079/f6zjnpS0jhO8CvzSz/zWzlGAHEhHpaOrLt/PHhn/h+vjNDBkypMOdf3BCqwXBObfWOXc18BLwqpk9YGbxwY8mItIxFG97njirpyImk6FDh3od57y1aR+C+cvdNuB/gLuAHWb2hbZ2YmYzzWybmeWb2X0ttLvFzJyZNXsDaBGRcFRXvIT82l588ovfYcKECV7HOW+tFgQzex8oAX4BpALzgKnABDNb2Ib1I4HHgVlAFnCbmWU10647cA+wqu3xRUQ81lBDWlUuuW40yd26eZ3mgrRlhDAfSHXOTXfO/Ydz7iXnXL5z7i7gqjasPwHId84VOOdqgWeBG5tp9yPgZ0B1W8OLiHitZv/bxFst9RUZvP76617HuSBt2Yew2TnnzrJ4dhv6SAWKGk0XB+adZGaXAYOccy+39ERmNt/Mcs0st7S0tA1di4gE177ti6jxRbLnSAoJCQlex7kgF3SmsnOu4EIDmFkE8HPgW23ob6FzLsc5l5OcnHyhXYuIXLDYg8tYVz2YOhfb4e5/0FQoLl1RAgxqNJ0WmHdCd2AU8JaZFQKXA4u1Y1lEwl5lCSkNu9hTn0VMTAwpKR37yPxQFITVwDAzyzCzGGAusPjEQudcuXOur3Mu3TmXDqwE5jjnckOQTUTkvB3f/Q8ADlUNZ/DgwUREdOzLwwU9vXOuHrgT/9nNecBzzrnNZrbAzOYEu38RkWB5p+4S/uXAHHoMu4axY8d6HeeCheTy1865JcCSJvPuP0vbqaHIJCJyoV4qquD56sv57axZREdGeh3ngnXs8Y2IiFfK80gqeYYbByRRV1PjdZp2oYIgInIeynf+hQXd/syYI4f461//6nWcdqGCICJyHl60Gxm9+y4qj/k67OWum1JBEBE5D8sLC0m0AQAqCCIiXZUrXsysww8ztUcUUVFRpKamtr5SB6CCICJyjsp3PM302I0k1UaQlpZGVFRIDtgMus7xW4iIhIrzEX1wGUsqM5k15yb6xXee28OoIIiInIsj6+nW8DFrmcVnMzO9TtOutMlIROQcNOx9BYC4+Als3LjR4zTtSwVBROQcHN+9mLXVA+heBmvWrPE6TrtSQRAR
aavacrqV57Ks8iJqyso6zeGmJ6ggiIi01YE3iKSBgxHZOOdUEEREuqr6klc46oslPjqLiIgIBg0a1PpKHYiOMhIRaaOdDf15sSyH5IRYElNTiY6O9jpSu1JBEBFpo/+t+gQ/PxLDx/M/T6yZ13HanQqCiEhbHC/ivV3buDwtjcSYGK/TBIUKQjjY9muoPthym25D4KLb/Y+3/gri+kH6bf7pjQugobrl9ZMubdT+h9B7HKTeAA21sPGBJo0bffM58S2o7yRInQ2+On9/KTOg31VQ8zFs+2UzHTb59jRgOvS70t9++2OQ9knoNRqO74GCp1rODpB2U6D9bij4I2R8ARIzoGwzFLXh0sOntV8Ew74GcclwaCXsW9r6+ifbr4J9r0HWtyEyDvYvh0MftP77X/KtQPs34OM1/vUBihdD+eazdBp4jogouORe/+OSl6BqL1w03z+9+y/+16Sl/mOSTr13dj8HruHUe2Hn/0Ltxy3/7glpMOQz/scFT0FsX/97B2DbY+Br5V4A3YdD2pxT7XuOhAHX+N9LW3/R8rrgf68OmHaqfb9PQN+JUFsG+b9rff1+U0613/l7SJkFSSOhshh2P9v6+gOvh55Z1Hz0fX4T9Tp/j3qERYsWccstt2CdbJSgguAFXwOs/Yb/nzppFOQvhKNbW14nefKpf+qdv/evd+KfevtjUFfe8vqDbmnU/nH/B2TqDeDqYevPGzV0jR42ejz8rlMFYctPILqHvyDUHoFNDzXpzHGGqG7+glB7BDY+CIlDT33An1GQmpGYcaqAbHzA/3okZvg/TNuy/mntH4TBt54qCBsfbH39k+0/gI33w8V3+T/g9y2FvEdaX3/41wPtX4XtvzlVEPY8D4V/anndyPhTBWH3X+DQilMFIX8hHHij5fUTh5567+T/Fny1p94LeY/A0byW10++8lRB2PKf/vfeiYKw4T9af+8NvvVUQdjwHzB0XqAg1MO677a8LsDF9wQKQqB99sP+D/iaw7DuO62v37j9R9+GuP7+glCxyz/dmrj+0DOLVTHTebOiiJ5RFZTV1HS6YgBgzjXzz9sB5OTkuNzcXK9jnJ+yzfiWXs6aqK+SNPp2hg0bRkVFBcuXLz+j6ejRo8nIyKCsrIy33377jOWXXXYZgwYN4tChQ7z//vtnLJ8wYQIpKSns37+fVatWnbF80qRJJCcnU1xc3OxJNlOmTKFXr14UFhaybt26M5ZPmzaN7t27k5+f3+xZmzNnziQ+Pp6tW7eSl5d3epEx44bZs4mOjmLjxg3k5+88Y/2bbroJM+Ojj9axq7AwsL4DjMioKG6cMwdwrF69mqKiotPWjY2NZfbs2QB88MFK9u3ff1r/3RITuW7GDADeffddSktLT1u/Z8+eTJs2DYA33niDsvJycL5A/xH0TU5mypWTAceyZcuoqKgIrOnvo3///lxxxRUAvPLq61TX1GCuAcOHz6JJTU1lQs5YcD6WLFlCfX19o94dgwcPJjs7G5zjxZf8o5gIVwf48FksmZmZXJo1jPq6Wl555ZXT1gUYNmwYI0aMoLq6hqXL3w2sX4MBDRZLVlYWwzIGUlFxlLffOvO9NerSUQwZPJiy8mO8/b7/fy3SVQERNFgs48aNI61/EocOH2LlypVnrJ8zLocBAwZwoPQwK1evO7m+IwqfRTPpiitI7p3I3r17Wbt27RnrT75yMr2SerG7qIQdO3eTOnAgqSm96dGzD0RE+/8WDVVnrHcGi4bIGH/7+kqIjPWv72to2/qB9ne+9A+e3rCRbzY0cPnllzN9+vTW1w1DZrbGOZfT7DIVhNCrr6/n0f/6ERXVsDI2nryYGLr7fNxQWXlG29WxseRHR9OroYGZVWe+eVfExrI7OprkhgamN7P87bg4SqKiGFhfz9TqMzcrLY+L40BUFEPq65nUzPKl8fF8HBlJZl0dE5q5TeBLCQkci4hgRG0t2bW1ZyxfnJBAVUQEI2trGdXM8he6daPOjDE1NYyo
qztj+XOJiWDGuOpqhjZZ3mDGC4mJAEysrmZwk+U1EREs7tYNgMlVVQw87QMXKiIieCWwfGpVFclNlpdFRPBaQgIA06uq6NXQcNryQ5GRvBG4sNmsykp6+HynLd8XFcU7cXEAzDl+nPgm/2t7oqL4ILD8luPHiW6yfGd0NKtjYwH4TEVF041QbIuO5qPYWCKd49bjx2lqc0wMG2NiiPP5uKmZ99b6mJgLeu99EBtLYeC9d20zy9+Ni6M48N77RDPvrTdOvPfq6pjUzHur8XvvE9XVnLhj8XEzDkZG8m5sLNUREf4i38q39fb4nDtcVcXNAwYwoqSEz372swwbNuyCn9MLnhcEM5sJ/AqIBH7vnHu4yfI7gK8DDUAFMN85t6Wl5+ywBaHqAMtWb+P9N99kZZ8+pA0e7HWiC9Ye76Bw+mLSGTcFdHg+H7FVVURXVhJz/DjRVVUcHDECzOhZXEzcsWPUJiT4f7p1oy4uDiJOP83qQv+uBlzl81Gwfj3f/e53iQ0U646mpYIQ9H0IZhYJPA5MB4qB1Wa2uMkH/jPOuScC7ecAPwdmBjtbyDmHe2s2qSVV7OHT/Pq22xjWp4/XqUQ6tHXr1pGXl0dJSQnHP/bvIO/Vqxd33303AHv27KFbt2707t37gotCbm4uic512GLQmlDsVJ4A5DvnCgDM7FngRuBkQXDOHW3Uvhvt86Uz/Ox7DTuyhrxj/0RsZqaKgUg7yM7OJjvbfymJ8vJySkpKqG20efLFF1/kyJEjxMXFkZqaSmpqKpmZmQw+j9F5Tk4OOTnNfrnuFEJREFKBxnv7ioGJTRuZ2deBbwIxwDXNPZGZzQfmA+f1x/SUc7D5R3xsyXz5WDarv3i914lEOhUzIykpiaSkpNPmf+Yzn6GkpOTkz7vvvktFRQWDBw/GOceLL75I//79SU1NJSUlhZiznGNQU1NDVFQUkZGRzS7vDMLmsFPn3OPA42b2WeAHwBebabMQWAj+fQihTXiBDr4Dpe/zw9LZ3JydQ2bv3l4nEukS+vfvT//+/bnssssAqK2tpS5wAEJFRQV79uxhw4YNgL+o9OvXj6lTpzJixAh8gQMFIiIiWLlyJStWrODee+/tdJesOCEUBaEEaHwFqLTAvLN5FhdWZS8AAAu6SURBVPifoCbygG/jAo7Xd2dveTaPTJnidRyRLismJubkKKB79+7cc889VFRUsHfv3pOjiBMf+Hv27OGZZ54hJSWFo0eP0rt3705bDCA0BWE1MMzMMvAXgrnAZxs3MLNhzrkdgcnZwA46k9IPiDj4Bh8cmcGAzEtIbzKkFRFvJSYmMnz4cIYPH37a/ISEBMaOHcvevXs5duyY/7yQTizoBcE5V29mdwJL8R92+qRzbrOZLQBynXOLgTvN7FqgDjhCM5uLOjLfpoeoakhgcfk47v1n7TsQ6Sj69evHrFmzAP+h0Z39kOSQ7ENwzi0BljSZd3+jx/eEIocnPv6IiH1LWHnkGiIuGkV6r15eJxKR89DZiwHoBjnBV72f/XUD+Hv5BL51vUYHIhK+VBCCrDDhCtJ334EbOZ4h2ncgImEsbA477Ywa9r7Oz1aV4SIi+e6113odR0SkRSoIwVJRSMRbM7nx8BRiRt1DWo8eXicSEWmRCkKQNMSl8X/7/5mNlX35zhebPfFaRCSsqCAEyesrVrD7WDo2fDipGh2ISAegncpB4FvzLeI33U8J8M3ADVpERMKdCkJ7q9wL2x4j0SroefHFDNToQEQ6CG0yam95j+Bo4DtlU/jT5zU6EJGOQyOE9lR9EN+O3/Kno6PJvuw6Urp39zqRiEibaYTQjnx5/wUN1bxcNoXHJk/2Oo6IyDlRQWgvNR/j2/oYeRUjyRw6mf6Bm7+LiHQU2mTUTnxbf0mUq+KFI1P4hq5ZJCIdkEYI7aHuKA1bfsG2ihHUDP0E/TQ6EJEOSCOEduAreZloV8HfjnyCb+i8AxHpoFQQ2kFewjQmFd7F8Ytnktytm9dx
RETOizYZXaj6Kn749ttsshT+MWOG12lERM6bCsKFaKim6q/pzN4/huE536ZPQoLXiUREzps2GV2AhtpK1nw8jMO1qXxD5x2ISAenEcIFeHnVBj46OJ2GkSM1OhCRDi8kIwQzm2lm28ws38zua2b5N81si5ltMLPlZjYkFLkuREPxSxxa+/84YPBvOu9ARDqBoBcEM4sEHgdmAVnAbWaW1aTZR0COc240sAj4z2DnuiC+eqrev4OZSa/S/5Isemt0ICKdQChGCBOAfOdcgXOuFngWuLFxA+fcm865ysDkSiAtBLnO3+6/kNhQwu/Kr+aeG27wOo2ISLsIRUFIBYoaTRcH5p3NvwKvNLfAzOabWa6Z5ZaWlrZjxHPgfFSt/yEba/oRMfbLJMXHe5NDRKSdhdVRRmb2eSAHeKS55c65hc65HOdcTnJycmjDBfj2/JX4yh08dfRq7rlikicZRESCIRRHGZUAgxpNpwXmncbMrgX+HfiEc64mBLnOnXMcXfV9Gmp70yvtFnrExnqdSESk3YRihLAaGGZmGWYWA8wFFjduYGZjgd8Cc5xzB0OQ6bz4Sl4iqT6fJUemcNfsf/I6johIuwr6CME5V29mdwJLgUjgSefcZjNbAOQ65xbj30SUCDxvZgB7nHNzgp3tnDjHkQ++S0RdEsWDb6NnXJzXiURE2lVITkxzzi0BljSZd3+jx9eGIseF8O17nT51efz5yA18/TPhVatERNqDzlRuo1WVKXxYeh3V6V+kh0YHItIJqSC00YMr1/NR7TQKrtfoQEQ6JxWENihYfDOX7I/h2kl3kRgT43UcEZGgUEFoha+mjPiPVzA1djTX5uR4HUdEJGhUEFqx6L015O3+CgljLuUmnXcgIp1YWJ2pHG58FUVsXvUWpRFRfE37DkSkk1NBaEHx0i/wjdTHSL90JN2070BEOjkVhLMpzyOt+h3eqRjD167XFU1FpPNTQTiL/av+nSoXxf6x3ydBowMR6QJUEJrRUL6d5NK/83TlJP55wjSv44iIhIQKQjN2Lf83nDNcv3nERelALBHpGlQQmmg4uov0qtdYcWwcX5z9Oa/jiIiEjApCE/nLv4Hh2D7wK8RFR3sdR0QkZFQQGmmoKGbo8ZdZdWwsn7/+C17HEREJKRWERt7eupZtVYPZnjJfowMR6XK0xzTAOcf31h3kwPG72Db7X7yOIyISchohBLz3xn9RfmAL/37VVcTqyCIR6YL0yQc01FZy2d4HeDp5KKOzH/E6joiIJ1QQgKffeo+P99xOr0tHMS4y0us4IiKe6PIFoaG+ng0ffki968PXrp/ndRwREc+EZB+Cmc00s21mlm9m9zWzfIqZrTWzejP7VCgynfDRS3fwlf5/ZMRlo4jRvgMR6cKCXhDMLBJ4HJgFZAG3mVlWk2Z7gHnAM8HO05irO86wY3+h3uBfZ+p+ByLStYVihDAByHfOFTjnaoFngRsbN3DOFTrnNgC+EOQ5aeOKH9MzsoLCEfcRrX0HItLFhaIgpAJFjaaLA/M8VV9bSUrRf/NhXSbTJ33J6zgiIp7rUOchmNl8M8s1s9zS0tILeq6VS+4jOaKcQ/3mExXRoV4GEZGgCMUnYQkwqNF0WmDeOXPOLXTO5TjncpKTk887UH1dNZeU/x+7q9OYMf2b5/08IiKdSSgKwmpgmJllmFkMMBdYHIJ+z+r9Jf9On+gytvT5V6J0ZJGICBCCguCcqwfuBJYCecBzzrnNZrbAzOYAmNl4MysGbgV+a2abg5Wnvq6WS478kZKaAUyf/YNgdSMi0uGE5Ouxc24JsKTJvPsbPV6Nf1NS0C15789cF3WMd3rcS6pGByIiJ3W5vamRqVfwlYjHuPr6B72OIiISVrrcV+TZw4cze/hwr2OIiISdLjdCEBGR5qkgiIgIoIIgIiIBKggiIgKoIIiISIAKgoiIACoIIiISoIIgIiIAmHPO6wznxcxKgd3nuXpf4FA7xuno9HqcTq/HKXotTtcZ
Xo8hzrlmLxfdYQvChTCzXOdcjtc5woVej9Pp9ThFr8XpOvvroU1GIiICqCCIiEhAVy0IC70OEGb0epxOr8cpei1O16lfjy65D0FERM7UVUcIIiLShAqCiIgAXbAgmNlMM9tmZvlmdp/XebxiZoPM7E0z22Jmm83sHq8zhQMzizSzj8zsJa+zeM3MksxskZltNbM8M7vC60xeMbNvBP5PNpnZn80szutMwdClCoKZRQKPA7OALOA2M8vyNpVn6oFvOeeygMuBr3fh16Kxe4A8r0OEiV8BrzrnRgBj6KKvi5mlAncDOc65UUAkMNfbVMHRpQoCMAHId84VOOdqgWeBGz3O5Ann3D7n3NrA42P4/9lTvU3lLTNLA2YDv/c6i9fMrCcwBfgDgHOu1jlX5m0qT0UB8WYWBSQAez3OExRdrSCkAkWNpovp4h+CAGaWDowFVnmbxHO/BL4D+LwOEgYygFLgfwOb0H5vZt28DuUF51wJ8CiwB9gHlDvnXvM2VXB0tYIgTZhZIvBX4N+cc0e9zuMVM7sBOOicW+N1ljARBVwG/I9zbixwHOiS+9zMrBf+LQkZwECgm5l93ttUwdHVCkIJMKjRdFpgXpdkZtH4i8HTzrkXvM7jscnAHDMrxL8p8Roz+5O3kTxVDBQ7506MGhfhLxBd0bXALudcqXOuDngBmORxpqDoagVhNTDMzDLMLAb/jqHFHmfyhJkZ/u3Dec65n3udx2vOue8559Kcc+n43xdvOOc65bfAtnDO7QeKzOziwKxpwBYPI3lpD3C5mSUE/m+m0Ul3sEd5HSCUnHP1ZnYnsBT/kQJPOuc2exzLK5OBLwAbzWxdYN73nXNLPMwk4eUu4OnAl6cC4Ese5/GEc26VmS0C1uI/Ou8jOuklLHTpChERAbreJiMRETkLFQQREQFUEEREJEAFQUREABUEEREJUEEQERFABUFERAJUEETamZlNMrMFXucQOVc6MU1ERACNEETanZk9b2ZXeZ1D5FypIIi0v1HABq9DiJwrFQSRdhS4126Mc67c6ywi50oFQaR9jaTrXiZaOjgVBJH2dSnaXCQdlAqCSPtSQZAOS4ediogIoBGCiIgEqCCIiAiggiAiIgEqCCIiAqggiIhIgAqCiIgAKggiIhLw/wEXk2OVh69dkwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig = plt.figure()\n",
    "i = 0  # index of the validation example to plot\n",
    "plt.plot(val_preds[i], label='LLCP', color='teal')\n",
    "plt.plot(lstsq_val_preds[i], label='least squares', linestyle='--', color='gray')\n",
    "plt.plot(val_outputs[i], label='true', linestyle='-.', color='orange')\n",
    "plt.xlabel(r'$i$')\n",
    "plt.ylabel(r'$y_i$')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
